from LSTM_casadi_mz1 import getRNNmodel_numpy_mz1
from LSTM_casadi_mz2 import getRNNmodel_numpy_mz2
from LSTM_casadi_mz3 import getRNNmodel_numpy_mz3
import numpy as np
# Number of consecutive time steps (state window and matching input windows)
# handed to each zone model per prediction call.
sequenceLength = 5

# Per-zone scaling bounds, loaded once at import time. Only element [0] of each
# array is read in this file, in the rescaling expression (max - min) * x + min,
# which looks like undoing min-max normalization of the model outputs --
# TODO confirm against the training pipeline.
data_min_mz1, data_max_mz1 = (
    np.loadtxt("./core_validation/Weights/data_min_mz1_vrd_lai.txt"),
    np.loadtxt("./core_validation/Weights/data_max_mz1_vrd_lai.txt"),
)
data_min_mz2, data_max_mz2 = (
    np.loadtxt("./core_validation/Weights/data_min_mz2_vrd_lai.txt"),
    np.loadtxt("./core_validation/Weights/data_max_mz2_vrd_lai.txt"),
)
data_min_mz3, data_max_mz3 = (
    np.loadtxt("./core_validation/Weights/data_min_mz3_vrd_lai.txt"),
    np.loadtxt("./core_validation/Weights/data_max_mz3_vrd_lai.txt"),
)

def _unscale_zone(scaled_states, data_min, data_max):
    """Invert min-max scaling of a zone's state history using element [0] of its bounds."""
    return np.array(scaled_states) * (data_max[0] - data_min[0]) + data_min[0]


def simulate_lstm_models(initial_states_mz1, initial_states_mz2, initial_states_mz3, irrig_rates, crop_coeffs, ref_evaps, rooting_depth, leaf_area_index, time_steps):
    """Roll out the three zone LSTM models (mz1, mz2, mz3) autoregressively.

    Each zone keeps its own state history, seeded with its initial states.
    At step ``j`` the zone model receives the state window
    ``history[j:j + sequenceLength]`` and the same ``[j, j + sequenceLength)``
    window of every input series. The state it returns is appended to the
    history, so later windows consume earlier predictions.

    Args:
        initial_states_mz1, initial_states_mz2, initial_states_mz3: Seed
            (scaled) state sequences, one per zone. The window arithmetic
            lines up only when each holds exactly ``sequenceLength`` entries.
            NOTE(review): confirm with callers.
        irrig_rates, crop_coeffs, ref_evaps, rooting_depth, leaf_area_index:
            Sliceable input series shared by all three zones. The last window
            ends at index ``time_steps``, so each series needs at least
            ``time_steps`` entries for every window to be full.
        time_steps: Controls the number of model calls per zone, which is
            ``time_steps - sequenceLength + 1`` (none if that is <= 0).

    Returns:
        Tuple ``(mz1, mz2, mz3)`` of numpy arrays holding each zone's full
        history (seed states followed by predictions), mapped back from the
        scaled space with that zone's ``data_min``/``data_max`` element [0].
    """
    # Copy the seeds so appending predictions never mutates the callers' data.
    x_mz1 = list(initial_states_mz1)
    x_mz2 = list(initial_states_mz2)
    x_mz3 = list(initial_states_mz3)

    for j in range(time_steps - sequenceLength + 1):
        # The same window applies to the state histories and the input series.
        window = slice(j, j + sequenceLength)
        u_current = irrig_rates[window]
        kc_current = crop_coeffs[window]
        et_current = ref_evaps[window]
        rd_current = rooting_depth[window]
        lai_current = leaf_area_index[window]

        # The zones are independent: each model sees only its own history.
        x_mz1.append(getRNNmodel_numpy_mz1(np.array(x_mz1[window]), u_current, kc_current, et_current, rd_current, lai_current))
        x_mz2.append(getRNNmodel_numpy_mz2(np.array(x_mz2[window]), u_current, kc_current, et_current, rd_current, lai_current))
        x_mz3.append(getRNNmodel_numpy_mz3(np.array(x_mz3[window]), u_current, kc_current, et_current, rd_current, lai_current))

    # Each zone is rescaled with its own bounds. Previously mz1 used
    # data_min_mz2[0] in its range, which was a copy-paste bug.
    x_mz1_unscaled = _unscale_zone(x_mz1, data_min_mz1, data_max_mz1)
    x_mz2_unscaled = _unscale_zone(x_mz2, data_min_mz2, data_max_mz2)
    x_mz3_unscaled = _unscale_zone(x_mz3, data_min_mz3, data_max_mz3)
    return x_mz1_unscaled, x_mz2_unscaled, x_mz3_unscaled
